package lic

import (
	"crypto/rand"
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"io"
	"os"
	"time"
)

// defaultFilePath is the directory GenRSAKeyPair falls back to when it is
// given an empty path.
const defaultFilePath = "./"

// PrivateKeyFileName is the name of the PEM file that holds the PKCS #1 RSA
// private key ("RSA PRIVATE KEY" block).
const PrivateKeyFileName = "private.key"

// publicKeyFileName is the name of the PEM file that holds the PKIX RSA
// public key ("PUBLIC KEY" block).
const publicKeyFileName = "public.pem"

// LicenseFileName is the name of the license file written by
// GenEncryptLicenseFile.
const LicenseFileName = "license.lic"

// GenRSAKeyPair generates a 2048-bit RSA key pair and saves it as two PEM
// files in filePath (defaultFilePath when filePath is empty):
//
//   - PrivateKeyFileName: the PKCS #1 private key ("RSA PRIVATE KEY"),
//     stored with mode 0600 because it is the secret half of the pair;
//   - publicKeyFileName: the PKIX public key ("PUBLIC KEY").
//
// Existing files are truncated and overwritten. Write and close errors are
// reported, so a nil error means both files were written completely.
func GenRSAKeyPair(filePath string) error {
	if len(filePath) == 0 {
		filePath = defaultFilePath
	}

	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return fmt.Errorf("generating RSA key: %w", err)
	}

	// Save the private key. It must not be readable by other users, so the
	// file is created with 0600 and a pre-existing file is reset to 0600.
	privatePath := fmt.Sprintf("%s/%s", filePath, PrivateKeyFileName)
	privateKeyBlock := &pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(rsaKey),
	}
	if err := writeKeyPEMFile(privatePath, privateKeyBlock, 0o600, true); err != nil {
		return fmt.Errorf("writing private key: %w", err)
	}

	// Save the public key with the same mode os.Create would use (0666,
	// subject to the umask).
	publicKeyBytes, err := x509.MarshalPKIXPublicKey(&rsaKey.PublicKey)
	if err != nil {
		return fmt.Errorf("marshaling public key: %w", err)
	}
	publicPath := fmt.Sprintf("%s/%s", filePath, publicKeyFileName)
	pubKeyBlock := &pem.Block{
		Type:  "PUBLIC KEY",
		Bytes: publicKeyBytes,
	}
	if err := writeKeyPEMFile(publicPath, pubKeyBlock, 0o666, false); err != nil {
		return fmt.Errorf("writing public key: %w", err)
	}

	return nil
}

// writeKeyPEMFile PEM-encodes block into path, creating the file with mode
// perm (subject to the umask) or truncating it if it already exists. When
// forcePerm is true the mode of a pre-existing file is reset to perm before
// any key material is written. The error from Close is returned as well,
// since some filesystems only report write failures at that point.
func writeKeyPEMFile(path string, block *pem.Block, perm os.FileMode, forcePerm bool) (err error) {
	f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
	if err != nil {
		return err
	}
	defer func() {
		if cerr := f.Close(); err == nil {
			err = cerr
		}
	}()

	if forcePerm {
		// OpenFile applies perm only when it creates the file.
		if err := f.Chmod(perm); err != nil {
			return err
		}
	}
	return pem.Encode(f, block)
}

// GenEncryptLicenseFile 生成并加密license文件
func GenEncryptLicenseFile(filePath string, expirationDays int) error {
	// 读取公钥file
	publicKeyFilePath := fmt.Sprintf("%s/%s", filePath, publicKeyFileName)
	publicKeyFile, err := os.Open(publicKeyFilePath)
	if err != nil {
		return err
	}
	defer publicKeyFile.Close()

	publicKeyBytes, err := io.ReadAll(publicKeyFile)
	if err != nil {
		return err
	}

	publicKeyBlock, _ := pem.Decode(publicKeyBytes)
	if publicKeyBlock == nil || publicKeyBlock.Type != "PUBLIC KEY" {
		return fmt.Errorf("failed to decode PEM block containing public key")
	}

	publicKey, err := x509.ParsePKIXPublicKey(publicKeyBlock.Bytes)
	if err != nil {
		return err
	}

	rsaPubKey, ok := publicKey.(*rsa.PublicKey)
	if !ok {
		return fmt.Errorf("invalid RSA public key")
	}

	// 加密 失效时间
	currentTime := time.Now()
	expirationTime := currentTime.AddDate(0, 0, expirationDays)
	//expirationTime := currentTime.Add(time.Second * 11)
	expirationTimeBytes := expirationTime.Format(time.RFC3339)
	cipherText, err := rsa.EncryptPKCS1v15(rand.Reader, rsaPubKey, []byte(expirationTimeBytes))
	if err != nil {
		return err
	}
	// 将过期时间和加密后的信息组合成许可证
	licenseData := append([]byte(expirationTimeBytes), cipherText...)

	// 保存加密后的 license
	licenseFilePath := fmt.Sprintf("%s/%s", filePath, LicenseFileName)
	licenseFile, err := os.Create(licenseFilePath)
	if err != nil {
		return err
	}
	defer licenseFile.Close()
	// 写license文件数据
	_, err = licenseFile.Write(licenseData)
	if err != nil {
		return err
	}

	return nil
}

// VerifyLicense reports whether the license file at licenseFilePath is
// intact and not yet expired. It uses the PKCS #1 RSA private key stored as
// an "RSA PRIVATE KEY" PEM block at privateKeyFilePath.
//
// The license must be laid out as GenEncryptLicenseFile writes it: an
// RFC 3339 expiration timestamp followed by its RSA PKCS #1 v1.5
// ciphertext. The ciphertext is exactly privateKey.Size() bytes (256 for a
// 2048-bit key).
//
// Return values:
//   - (true, nil) for an intact license that has not expired.
//   - (false, nil) for an intact license that has expired.
//   - (false, err) for unreadable files, a malformed key, a license too short
//     to hold a ciphertext, a failed decryption, a timestamp that differs
//     from its decrypted ciphertext, or an unparsable timestamp.
//
// NOTE(review): the ciphertext only proves the license was produced with the
// matching public key. Anyone holding that key can create licenses.
func VerifyLicense(licenseFilePath, privateKeyFilePath string) (bool, error) {
	// Load and parse the private key.
	privateKeyFile, err := os.Open(privateKeyFilePath)
	if err != nil {
		return false, err
	}
	defer privateKeyFile.Close()

	privateKeyBytes, err := io.ReadAll(privateKeyFile)
	if err != nil {
		return false, fmt.Errorf("reading private key: %w", err)
	}

	privateKeyBlock, _ := pem.Decode(privateKeyBytes)
	if privateKeyBlock == nil || privateKeyBlock.Type != "RSA PRIVATE KEY" {
		return false, fmt.Errorf("failed to decode PEM block containing private key")
	}

	privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
	if err != nil {
		return false, fmt.Errorf("parsing private key: %w", err)
	}

	// Read the license data.
	licenseFile, err := os.Open(licenseFilePath)
	if err != nil {
		return false, err
	}
	defer licenseFile.Close()

	licenseData, err := io.ReadAll(licenseFile)
	if err != nil {
		return false, fmt.Errorf("reading license: %w", err)
	}

	// Split the license into timestamp and ciphertext. The ciphertext is one
	// RSA modulus long, so its size comes from the key. Requiring at least
	// one timestamp byte keeps the slicing below in range for any input.
	cipherLen := privateKey.Size()
	if len(licenseData) <= cipherLen {
		return false, fmt.Errorf("invalid license")
	}
	split := len(licenseData) - cipherLen
	expirationTimeBytes, encryptedExpiration := licenseData[:split], licenseData[split:]

	// Check the timestamp against its ciphertext before trusting it.
	decryptedExpiration, err := rsa.DecryptPKCS1v15(rand.Reader, privateKey, encryptedExpiration)
	if err != nil {
		return false, fmt.Errorf("decrypting license: %w", err)
	}
	if string(expirationTimeBytes) != string(decryptedExpiration) {
		return false, fmt.Errorf("invalid license") // timestamp was tampered with
	}

	expirationTime, err := time.Parse(time.RFC3339, string(expirationTimeBytes))
	if err != nil {
		return false, fmt.Errorf("parsing license expiration: %w", err)
	}

	// The license is valid strictly before its expiration instant.
	return time.Now().Before(expirationTime), nil
}
